{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.seterr(divide='ignore', invalid='ignore')  # globally silence NumPy divide-by-zero / invalid-value warnings (results become inf/NaN without notice)\n",
    "# 获取数据\n",
    "def loadDataSet():\n",
    "    postingList = [['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],\n",
    "                   ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],\n",
    "                   ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],\n",
    "                   ['stop', 'posting', 'stupid', 'worthless', 'garbage'],\n",
    "                   ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],\n",
    "                   ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]\n",
    "    classVec = [0, 1, 0, 1, 0, 1] #1表示侮辱性言论，0表示正常\n",
    "    return postingList, classVec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def createVocabList(dataSet):\n",
    "    vocabSet = set([])\n",
    "    for document in dataSet:\n",
    "        vocabSet = vocabSet | set(document)\n",
    "    return list(vocabSet)\n",
    "\n",
    "# 对输入的词汇表构建词向量\n",
    "def setOfWords2Vec(vocabList, inputSet):\n",
    "    returnVec = np.zeros(len(vocabList)) #生成零向量的array\n",
    "    for word in inputSet:\n",
    "        if word in vocabList:\n",
    "            returnVec[vocabList.index(word)] = 1 #有单词，该位置填充1\n",
    "        else:\n",
    "            print(\"the word: %s is not in my Vocabulary\" % word)\n",
    "            # pass\n",
    "    return returnVec  #返回0，1的向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the toy corpus and build the vocabulary of distinct words it contains\n",
    "listPosts, listClasses = loadDataSet()\n",
    "myVocabList = createVocabList(listPosts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Vocabulary size: number of distinct words across all posts\n",
    "len(myVocabList)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trainNB0(trainMatrix, trainCategory):\n",
    "    numTrainDocs = len(trainMatrix)  #文档数目\n",
    "    numWord = len(trainMatrix[0])  #词汇表数目\n",
    "    print(numTrainDocs, numWord)\n",
    "    pAbusive = sum(trainCategory) / len(trainCategory) #p1, 出现侮辱性评论的概率 [0, 1, 0, 1, 0, 1]\n",
    "    p0Num = np.zeros(numWord)\n",
    "    p1Num = np.zeros(numWord)\n",
    "\n",
    "    p0Demon = 0\n",
    "    p1Demon = 0\n",
    "\n",
    "    for i in range(numTrainDocs):\n",
    "        if trainCategory[i] == 0:\n",
    "            p0Num += trainMatrix[i] #向量相加\n",
    "            p0Demon += sum(trainMatrix[i]) #向量中1累加其和\n",
    "        else:\n",
    "            p1Num += trainMatrix[i]\n",
    "            p1Demon += sum(trainMatrix[i])\n",
    "    p0Vec = p0Num / p0Demon\n",
    "    p1Vec = p1Num / p1Demon\n",
    "\n",
    "    return p0Vec, p1Vec, pAbusive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One binary word-presence vector per post, in the same order as listPosts\n",
    "trainMat = [setOfWords2Vec(myVocabList, post) for post in listPosts]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([0., 1., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
       "        1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),\n",
       " array([0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0.]),\n",
       " array([0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1.,\n",
       "        0., 0., 1., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0.]),\n",
       " array([0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,\n",
       "        0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]),\n",
       " array([1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n",
       "        0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 1., 0., 1.]),\n",
       " array([0., 1., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the encoded training vectors (one array per post)\n",
    "trainMat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6 32\n"
     ]
    }
   ],
   "source": [
    "# Fit the unsmoothed model; these names are reassigned by the trainNB1 call in a later cell\n",
    "p0Vec, p1Vec, pAbusive = trainNB0(trainMat, listClasses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trainNB1(trainMatrix, trainCategory):\n",
    "    numTrainDocs = len(trainMatrix)  #文档数目\n",
    "    numWord = len(trainMatrix[0])  #词汇表数目\n",
    "    pAbusive = sum(trainCategory) / len(trainCategory) #p1, 出现侮辱性评论的概率\n",
    "    p0Num = np.ones(numWord)  #修改为1\n",
    "    p1Num = np.ones(numWord)\n",
    "\n",
    "    p0Demon = 2 #修改为2\n",
    "    p1Demon = 2\n",
    "\n",
    "    for i in range(numTrainDocs):\n",
    "        if trainCategory[i] == 0:\n",
    "            p0Num += trainMatrix[i] #向量相加\n",
    "            p0Demon += sum(trainMatrix[i]) #向量中1累加其和\n",
    "        else:\n",
    "            p1Num += trainMatrix[i]\n",
    "            p1Demon += sum(trainMatrix[i])\n",
    "    p0Vec = np.log(p0Num / p0Demon)  #求对数\n",
    "    p1Vec = np.log(p1Num / p1Demon)\n",
    "\n",
    "    return p0Vec, p1Vec, pAbusive\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refit with smoothing and log probabilities; this replaces the values assigned from trainNB0\n",
    "p0Vec, p1Vec, pAbusive = trainNB1(trainMat, listClasses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-2.56494936, -2.56494936, -2.56494936, -2.56494936, -2.56494936,\n",
       "       -2.56494936, -3.25809654, -2.56494936, -3.25809654, -3.25809654,\n",
       "       -3.25809654, -3.25809654, -2.56494936, -3.25809654, -2.56494936,\n",
       "       -2.56494936, -2.56494936, -2.56494936, -3.25809654, -2.56494936,\n",
       "       -3.25809654, -3.25809654, -2.15948425, -2.56494936, -2.56494936,\n",
       "       -3.25809654, -1.87180218, -2.56494936, -2.56494936, -2.56494936,\n",
       "       -3.25809654, -2.56494936])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class-0 vector from the most recent training call (log probabilities when trainNB1 ran last)\n",
    "p0Vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['love'] classified as 0\n"
     ]
    }
   ],
   "source": [
    "def classifyNB(vec2Classify, p0Vc,  p1Vc, pClass1):\n",
    "    p1 = sum(vec2Classify * p1Vc) * pClass1\n",
    "    p0 = sum(vec2Classify * p0Vc) * (1-pClass1)\n",
    "    # p1 = sum(vec2Classify * p1Vc) + np.log(pClass1)    #取对数，防止结果溢出\n",
    "    # p0 = sum(vec2Classify * p0Vc) + np.log(1 - pClass1)\n",
    "    if p1 > p0:\n",
    "        return 1\n",
    "    else:\n",
    "        return 0\n",
    "# Smoke test: encode a one-word document and classify it with the most recently trained parameters\n",
    "testEntry = ['love']\n",
    "thisDoc = setOfWords2Vec(myVocabList, testEntry)\n",
    "print(testEntry, 'classified as', classifyNB(thisDoc, p0Vec, p1Vec, pAbusive))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
